#----------------------------------------------------------------------------
# Imports and environment setup.
# NOTE(review): this section has import-time side effects (an environment
# variable, a sys.path entry, a CUDA cache flush) and relies on
# machine-specific absolute paths.
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
# sys.path.append('/public/chenhaozhe/4_model_robust_hash/code_clean/shrinkbench_CHZ_new')
# import models

import os
# Sets WEIGHTSPATH for the whole process at import time (hard-coded path).
os.environ['WEIGHTSPATH'] = '/public/chenhaozhe/4_model_robust_hash/code_clean/shrinkbench_CHZ_new/pretrained/'
from IPython.display import clear_output
import pdb
import numpy as np

import math


# Makes an external Dynamic-convolution-Pytorch checkout importable; the
# Dynamic_conv2d import that would use it is currently commented out.
sys.path.append('/data/liuruiheng/MM_code/Dynamic-convolution-Pytorch')
# from dynamic_conv import Dynamic_conv2d

import torch.nn.utils.prune as prune
#----------------------------------------------------------------------------

# Releases memory cached by PyTorch's CUDA allocator, at import time.
torch.cuda.empty_cache()

# Redundant: torch is already imported at the top of this file.
import torch

# Define the function based on the given equation
def phi_i(x, w_i, beta, alpha, m, lambda_value):
    """
    Compute phi^(i) for each row of ``x``.

    With ``s = 1 + lambda_value * w_i``, ``y = x * s`` and ``t = s ** beta``
    (flattened to a row vector of shape ``(1, -1)``), this returns::

        sum over dim 1 of  y ** beta * ((t - 1) / alpha - t)

    All work happens on ``x``'s device. Tensor-valued ``w_i``, ``beta``,
    ``alpha`` and ``lambda_value`` are moved there. Plain Python numbers are
    accepted for ``beta``, ``alpha`` and ``lambda_value``.

    :param x: 2-D input tensor, presumably ``(n_samples, m)``.
    :param w_i: weight tensor. Because ``t`` is flattened to ``(1, -1)``, it
        must hold ``m`` elements (e.g. shape ``(m,)`` or ``(1, m)``) to
        broadcast against ``y`` when ``n_samples > 1``.
    :param beta: exponent (scalar or tensor).
    :param alpha: divisor (scalar or tensor).
    :param m: unused; kept for call compatibility.
    :param lambda_value: scale applied to ``w_i`` (scalar or tensor).
    :return: 1-D tensor with one value per row of ``x``, on ``x``'s device.
    """
    device = x.device
    w_i = w_i.to(device)
    # Accept Python scalars as well as tensors living on any device.
    if torch.is_tensor(beta):
        beta = beta.to(device)
    if torch.is_tensor(alpha):
        alpha = alpha.to(device)
    if torch.is_tensor(lambda_value):
        lambda_value = lambda_value.to(device)

    scale = 1 + lambda_value * w_i          # (1 + lambda * w_i)
    y = x * scale
    # Flatten the powered term to a row so it broadcasts across the rows of y.
    term = (scale ** beta).view(1, -1)

    return torch.sum(y ** beta * ((term - 1) / alpha - term), dim=1)

# # Example usage:
# # Set random seed for reproducibility
# torch.manual_seed(0)

# # Assume some example values for y, w_i, beta, alpha, and m
# n_samples = 5  # number of samples, i.e., batch size
# m = 3          # number of elements to sum over

# # Randomly generated example tensors for y and w_i
# y = torch.randn(n_samples, m)
# w_i = torch.randn(m)  # phi_i reshapes its (1 + lambda*w_i)^beta term to (1, -1), so w_i needs m elements

# # Example values for alpha and beta
# beta = torch.tensor(2.0)  # just an example value, should be defined according to your specific problem
# alpha = torch.tensor(1.5) # just an example value, should be defined according to your specific problem

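# # Calculate phi_i using the function (lambda_value is a required argument)
# lambda_value = torch.tensor(0.75)  # example value
# phi_i_values = phi_i(y, w_i, beta, alpha, m, lambda_value)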

# # Print the results
# phi_i_values

# Define the function based on the new given equation
def T_i(w_i, beta, m, lambda_value):
    """
    Compute the threshold T^(i).

    With ``t = (1 + lambda_value * w_i) ** beta``, this returns the sum over
    *all* elements of ``(t - 1) / t`` as a 0-dim tensor.

    All work happens on ``w_i``'s device. Tensor-valued ``beta`` and
    ``lambda_value`` are moved there. Plain Python numbers are also accepted.

    :param w_i: weight tensor (any shape; every element is summed).
    :param beta: exponent (scalar or tensor).
    :param m: unused; kept for call compatibility.
    :param lambda_value: scale applied to ``w_i`` (scalar or tensor).
    :return: 0-dim tensor on ``w_i``'s device.
    """
    device = w_i.device
    if torch.is_tensor(beta):
        beta = beta.to(device)
    if torch.is_tensor(lambda_value):
        lambda_value = lambda_value.to(device)

    term_pow_beta = (1 + lambda_value * w_i) ** beta
    # Summed over every element (no dim argument), yielding a scalar threshold.
    return torch.sum((term_pow_beta - 1) / term_pow_beta)

# # Example usage:
# # Assume some example values for w_i, beta, and m
# n_samples = 5  # number of samples, i.e., batch size
# m = 3          # number of elements to sum over

# # Randomly generated example tensor for w_i
# w_i = torch.randn(n_samples, m)

# # Example value for beta
# beta = torch.tensor(2.0)  # just an example value, should be defined according to your specific problem

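# # Calculate T_i using the function (lambda_value is a required argument)
# lambda_value = torch.tensor(0.75)  # example value
# T_i_values = T_i(w_i, beta, m, lambda_value)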

# # Print the results
# T_i_values

# Define the function to compute h(i) based on the comparison between phi(i)(y) and T^(i)
def compute_h_i(phi_i_vals, T_i_vals):
    """
    Return the indicator h^(i): 1 where phi^(i)(y) >= T^(i), otherwise 0.

    :param phi_i_vals: Tensor of phi^(i)(y) values.
    :param T_i_vals: Tensor (or broadcastable value) of T^(i) thresholds.
    :return: int32 tensor of 0/1 flags with the broadcast shape of the inputs.
    """
    meets_threshold = phi_i_vals >= T_i_vals
    return meets_threshold.to(dtype=torch.int32)

# # Example usage:
# # Assume some example values for phi_i_vals and T_i_vals
# # (For simplicity, we use the previously calculated values)

# # Previously computed values for phi_i and T_i
# phi_i_vals = torch.tensor([-9.7649, -4.5198, -3.6372, -1.4624, -2.4515])
# T_i_vals = torch.tensor([   -8.3760,     1.2758, -1135.1332,  -147.7731,  -262.2310])

# # Compute h(i) values
# h_i_values = compute_h_i(phi_i_vals, T_i_vals)

# # Print the results
# h_i_values


class SigmoidTimesTenActivation(nn.Module):
    """Element-wise ``10 * sigmoid(x)``: rescales sigmoid's (0, 1) output to (0, 10)."""

    def forward(self, x):
        squashed = torch.sigmoid(x)
        return squashed * 10


# pre-processing
class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d -> SELU.

    The convolution always uses ``padding=1``. With the ``kernel_size=3`` used
    throughout this file, that preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Attribute names and registration order (conv, bn, selu) determine the
        # state_dict keys and the parameter-initialization order; keep them.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.selu = nn.SELU()

    def forward(self, x):
        return self.selu(self.bn(self.conv(x)))

class FullyConnectedSELU(nn.Module):
    """
    Linear -> SELU applied to the *fully* flattened input.

    NOTE: the flatten collapses every dimension, including any batch
    dimension. ``in_features`` must therefore equal ``x.numel()``, and the
    output is a 1-D tensor of length ``out_features``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.selu = nn.SELU()

    def forward(self, x):
        flat = x.flatten()  # start_dim=0: the batch dim is flattened too
        return self.selu(self.fc(flat))

class DilatedConvBlock(nn.Module):
    """
    Dilated Conv2d -> BatchNorm2d -> SELU.

    :param in_channels: input channel count.
    :param out_channels: output channel count.
    :param kernel_size: convolution kernel size.
    :param padding: zero-padding passed to the convolution. Callers in this
        file pass ``padding=dilation``, which preserves the spatial size for a
        3x3 kernel (H + 2*d - d*(3 - 1) = H).
    :param dilation: convolution dilation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(DilatedConvBlock, self).__init__()
        # Use the caller-supplied padding rather than silently substituting dilation.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_channels)
        self.selu = nn.SELU()

    def forward(self, x):
        # Errors from the convolution (e.g. a wrong channel count) propagate
        # to the caller. There is deliberately no debugger hook: catching the
        # error and continuing would feed the un-convolved input to BatchNorm.
        x = self.conv(x)
        x = self.bn(x)
        return self.selu(x)

class ParallelDilatedConvLayers(nn.Module):
    """
    Apply one DilatedConvBlock per dilation rate to the same input and
    concatenate the results along dim 1.

    For batched NCHW input, dim 1 is the channel dim, so the output has
    ``out_channels * len(dilation_list)`` channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation_list):
        super().__init__()
        # Padding equals the dilation for every branch.
        self.blocks = nn.ModuleList([
            DilatedConvBlock(in_channels, out_channels, kernel_size, padding=rate, dilation=rate)
            for rate in dilation_list
        ])

    def forward(self, x):
        return torch.cat([branch(x) for branch in self.blocks], 1)

# Pre-processing: the weights of convolutional and fully connected layers are
# pooled to 128x128 and then normalized by the helpers below.
def calculate_lambda_prime(lambda_value):
    """
    Min-max normalize a tensor, then map it logarithmically onto [0, 10].

    Computes ``10 * log10(9 * (v - min) / (max - min + 1e-6) + 1)``. The
    epsilon avoids division by zero for a constant input, which maps to all
    zeros.

    :param lambda_value: tensor to normalize (global min/max over all elements).
    :return: tensor of the same shape with values in [0, 10).
    """
    lo = torch.min(lambda_value)
    hi = torch.max(lambda_value)
    # (lambda - y_min) / (y_max - y_min + eps)
    unit = (lambda_value - lo) / (hi - lo + 1e-6)
    # 10 * log10(9 * unit + 1): 0 -> 0 and ~1 -> ~10
    return torch.log10((unit * 9) + 1) * 10

def propocess_conv_fun(conv_layer_weight):
    """
    Normalize a 2-D (pooled) convolution weight matrix row by row.

    Each row is rescaled independently with ``calculate_lambda_prime``
    (min-max normalization followed by ``10 * log10(9 * x + 1)``).

    :param conv_layer_weight: 2-D tensor, e.g. the 128x128 pooled Conv2d
        weight produced in ``propocess_fun``.
    :return: tensor with the same shape, dtype and device, every row normalized.
    :raises ValueError: if ``conv_layer_weight`` is not 2-D.
    """
    weights = conv_layer_weight
    if weights.dim() != 2:
        raise ValueError(
            f"expected a 2-D weight matrix, got shape {tuple(weights.shape)}"
        )

    normalized_weights = torch.empty_like(weights)
    # Iterate over rows (dim 0) so every row of the uninitialized empty_like
    # buffer is written, including for non-square inputs.
    for row in range(weights.size(0)):
        normalized_weights[row:row + 1, :] = calculate_lambda_prime(weights[row:row + 1, :])
    return normalized_weights

def propocess_fc_fun(fc_layer_weight_data):
    """
    Normalize a 2-D (pooled) fully-connected weight matrix row by row.

    Each row is rescaled independently with ``calculate_lambda_prime``
    (min-max normalization followed by ``10 * log10(9 * x + 1)``).

    :param fc_layer_weight_data: 2-D tensor, e.g. the 128x128 pooled Linear
        weight produced in ``propocess_fun``.
    :return: tensor with the same shape, dtype and device, every row normalized.
    :raises ValueError: if ``fc_layer_weight_data`` is not 2-D.
    """
    weights = fc_layer_weight_data
    if weights.dim() != 2:
        raise ValueError(
            f"expected a 2-D weight matrix, got shape {tuple(weights.shape)}"
        )

    normalized_weights = torch.empty_like(weights)
    # Iterate over rows (dim 0) so every row of the uninitialized empty_like
    # buffer is written, including for non-square inputs.
    for row in range(weights.size(0)):
        normalized_weights[row:row + 1, :] = calculate_lambda_prime(weights[row:row + 1, :])
    return normalized_weights

def propocess_fun(model):
    """
    Prune every Conv2d / Linear layer of ``model`` and collect pooled,
    normalized weight maps.

    Each Conv2d and Linear module is visited exactly once, in
    ``model.modules()`` order. For each one:
      1. ``prune.l1_unstructured(..., amount=0.8)`` is applied IN PLACE; the
         module gains ``weight_orig`` / ``weight_mask`` and a pruning hook.
      2. The pruned weight is viewed as a single 2-D map and
         adaptive-average-pooled to 128x128.
      3. The map is normalized with ``propocess_conv_fun`` / ``propocess_fc_fun``.

    :param model: an ``nn.Module``; it is modified in place by pruning.
    :return: ``(beta_c, beta_f, n, l)``.
        - ``beta_c`` is an ``(n, 128, 128)`` stack of conv maps.
        - ``beta_f`` is an ``(l, 128, 128)`` stack of linear maps.
        - ``n`` / ``l`` are the numbers of Conv2d / Linear layers found.
        - When no layer of a kind exists, its stack is a CPU
          ``torch.ones((32, 128, 128))`` placeholder.
    """
    target_size = (128, 128)  # output size of the adaptive average pooling
    beta_c = []
    beta_f = []

    # model.modules() yields each submodule exactly once. Nesting
    # named_modules() inside named_modules() would revisit a layer once per
    # ancestor, duplicating entries and re-pruning the same layer repeatedly.
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            prune.l1_unstructured(layer, 'weight', amount=0.8)
            w = layer.weight  # masked weight after pruning
            # (out, in, kh, kw) -> one (out*in) x (kh*kw) map, pooled to 128x128
            pooled = nn.functional.adaptive_avg_pool2d(
                w.view(1, 1, w.shape[0] * w.shape[1], w.shape[2] * w.shape[3]),
                target_size,
            ).squeeze()
            beta_c.append(propocess_conv_fun(pooled))
        elif isinstance(layer, nn.Linear):
            prune.l1_unstructured(layer, 'weight', amount=0.8)
            # (out, in) -> (1, out, in), pooled to (1, 128, 128), then squeezed
            pooled = nn.functional.adaptive_avg_pool2d(
                layer.weight.view(1, -1, layer.in_features), target_size
            ).squeeze()
            beta_f.append(propocess_fc_fun(pooled))

    num_conv = len(beta_c)
    num_linear = len(beta_f)
    beta_c = torch.stack(beta_c, dim=0) if beta_c else torch.ones((32, 128, 128))
    beta_f = torch.stack(beta_f, dim=0) if beta_f else torch.ones((32, 128, 128))

    return beta_c, beta_f, num_conv, num_linear

class DualBranchNet(nn.Module):
    """
    Two-branch network.

    - Branch 1 is a plain stack of seven 32->32 conv blocks.
    - Branch 2 mixes conv blocks with parallel dilated convolutions.
    - Each branch ends in two SELU fully connected layers that produce a
      50-dim vector.
    - The two vectors are concatenated, mixed by a 100 -> 50 SELU layer and
      scaled into (0, 10) by ``SigmoidTimesTenActivation``.

    NOTE: ``FullyConnectedSELU`` flattens its whole input (``torch.flatten(x)``),
    so after pooling to 32x16x16 each branch must see exactly one 4-D sample
    (batch size 1) to yield the 32*16*16 features its first FC layer expects.
    The two 50-dim outputs are then concatenated along dim 0 into the 100
    features ``self.merge`` expects.
    """

    def __init__(self):
        super().__init__()
        dilations = [3, 5, 7, 9]

        # Branch 1: downsample by 4, seven 32->32 conv blocks, pool to 16x16, two FC layers.
        # (An optional Dynamic_conv2d input stage, 49 -> 32 channels, is disabled.)
        self.branch1 = nn.Sequential(
            nn.MaxPool2d((4, 4)),
            *[ConvBlock(in_channels=32, out_channels=32, kernel_size=3) for _ in range(7)],
            nn.AdaptiveMaxPool2d((16, 16)),
            FullyConnectedSELU(32 * 16 * 16, 2048),
            FullyConnectedSELU(2048, 50),
        )

        # Branch 2: downsample by 4, two conv blocks, four parallel dilated layers
        # (each emitting 4 x 32 = 128 channels), a 128->32 conv block, pool, two FC layers.
        # (An optional Dynamic_conv2d input stage, 2 -> 32 channels, is disabled.)
        self.branch2 = nn.Sequential(
            nn.MaxPool2d((4, 4)),
            *[ConvBlock(in_channels=32, out_channels=32, kernel_size=3) for _ in range(2)],
            ParallelDilatedConvLayers(in_channels=32, out_channels=32, kernel_size=3, dilation_list=dilations),
            *[
                ParallelDilatedConvLayers(in_channels=128, out_channels=32, kernel_size=3, dilation_list=dilations)
                for _ in range(3)
            ],
            ConvBlock(in_channels=32 * 4, out_channels=32, kernel_size=3),
            nn.AdaptiveMaxPool2d((16, 16)),
            FullyConnectedSELU(32 * 16 * 16, 2048),
            FullyConnectedSELU(2048, 50),
        )

        # Fuse the two 50-dim branch outputs.
        self.merge = FullyConnectedSELU(100, 50)

        # Final squashing into (0, 10).
        self.output = SigmoidTimesTenActivation()

    def forward(self, x1, x2):
        # Default torch.cat dim is 0; the branch outputs are 1-D here.
        features = torch.cat((self.branch1(x1), self.branch2(x2)))
        return self.output(self.merge(features))




